import time
import os.path
import torch
import argparse
from pre_process import *
from utils import *
from exp import Experiments
from utils import print_local_time, set_seed

def _str2bool(value):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``
    because every non-empty string is truthy. With it, ``--cuda False``
    silently kept CUDA enabled. This helper parses the usual spellings
    explicitly instead.

    Args:
        value: the raw command-line string (a bool is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):  # defensive: tolerate an already-parsed bool
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()

parser.add_argument("--dataset", type=str, default="environment", help="dataset")

## Model parameters
parser.add_argument("--pre_train", type=str, default="bert", help="Pre_trained model")
parser.add_argument(
    "--hidden", type=int, default=64, help="dimension of hidden layers in MLP"
)
parser.add_argument(
    "--embed_size", type=int, default=6, help="dimension of bubble embeddings"
)
parser.add_argument("--dropout", type=float, default=0.05, help="dropout")
parser.add_argument("--margin", type=float, default=0.05, help="margin for containing")
parser.add_argument(
    "--epsilon", type=float, default=-0.1, help="margin for disjointness"
)
parser.add_argument("--phi", type=float, default=0.5, help="minimum volume of bubble")
parser.add_argument(
    "--alpha", type=float, default=0.25, help="weight of geometric containment loss"
)
parser.add_argument(
    "--beta", type=float, default=0.25, help="weight of geometric disjointness loss"
)
parser.add_argument(
    "--gamma", type=float, default=1.0, help="weight of regularization loss"
)
parser.add_argument("--delta", type=float, default=0.5, help="weight of prob loss")
parser.add_argument(
    "--negsamples", type=int, default=25, help="Number of negative samples per node"
)

## Training hyper-parameters
parser.add_argument(
    "--expID",
    type=int,
    default=0,
    help="experiment index (used in the processed data file name)",
)
parser.add_argument("--epochs", type=int, default=30, help="training epochs")
parser.add_argument("--batch_size", type=int, default=200, help="training batch size")
parser.add_argument(
    "--lr", type=float, default=2e-5, help="learning rate for pre-trained model"
)
parser.add_argument(
    "--lr_projection",
    type=float,
    default=1e-3,
    help="learning rate for projection layers",
)
parser.add_argument("--eps", type=float, default=1e-8, help="adamw_epsilon")
parser.add_argument("--optim", type=str, default="adamw", help="Optimizer")

## Others
# Boolean flags use _str2bool so that e.g. "--cuda False" really disables them.
parser.add_argument("--cuda", type=_str2bool, default=True, help="use cuda for training")
parser.add_argument("--gpu_id", type=int, default=0, help="which gpu")
parser.add_argument("--seed", type=int, default=42, help="Seed for random generator")

parser.add_argument(
    "--contrastive", type=_str2bool, default=True, help="Contrastive Loss for centers"
)
parser.add_argument(
    "--radratio",
    type=_str2bool,
    default=True,
    help="Volume Ratio loss between child and parent",
)
parser.add_argument("--minvol", type=float, default=0.2, help="Min vol of child wrt parent")

parser.add_argument(
    "--theta",
    type=float,
    default=0.3,
    help="Weight of center distance in scoring function",
)

# ---------------------------------------------------------------------------
# Script entry: parse arguments, prepare data/result directories, then train
# the model and run prediction.
# ---------------------------------------------------------------------------
start_time = time.time()
print("Start time at : ")
print_local_time()

args = parser.parse_args()
# Use CUDA only when it was requested *and* a device is actually available.
args.cuda = bool(torch.cuda.is_available() and args.cuda)
if args.cuda:
    torch.cuda.set_device(args.gpu_id)

print(args)

# set_seed comes from utils; presumably seeds the RNGs for reproducibility.
set_seed(args.seed)

# exist_ok=True replaces the racy "check exists, then create" pattern.
processed_dir = os.path.join("../data/", args.dataset, "processed")
os.makedirs(processed_dir, exist_ok=True)

# Build the processed data only if this experiment's pickle is missing.
# NOTE(review): this file name (including the "_" before ".pkl") is assumed to
# match what create_data writes — confirm in pre_process.
data_file = os.path.join(processed_dir, f"taxonomy_data_{args.expID}_.pkl")
if not os.path.isfile(data_file):
    create_data(args)

resdir = f"../result/{args.dataset}"
os.makedirs(resdir, exist_ok=True)

exp = Experiments(args)

# Train the model, then run prediction tagged "test".
exp.train()
exp.predict(tag="test")

print(f"Time used :{time.time() - start_time:.01f}s")
print("End time at : ")
print_local_time()
print("************END***************")
